Skip to content

Add associative scan#30

Merged
KaelanDt merged 19 commits into
mainfrom
associative_scan
Jun 5, 2024
Merged

Add associative scan#30
KaelanDt merged 19 commits into
mainfrom
associative_scan

Conversation

@SamDuffield

Copy link
Copy Markdown
Contributor

First attempt at using jax.lax.associative_scan #14 , but it's throwing a matmul contracting dimensions error and I'm not sure why.

@AdrienCorenflos AdrienCorenflos left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Associative scan takes vectorised operators (given that the operations across the leaves of the computational tree are batched).

Comment thread thermox/sampler.py Outdated
@SamDuffield

Copy link
Copy Markdown
Contributor Author

Update: associative_scan now working but seems like something is wrong with the calculations so I need to check the maths again

@SamDuffield

Copy link
Copy Markdown
Contributor Author

Ok I fixed the maths! At the cost of doubling the number of expm_vp calls, we might be able to halve it again with further thought although I'm not sure.

Next step is to add associative_scan for log_prob

@SamDuffield SamDuffield marked this pull request as ready for review June 3, 2024 12:50
Comment thread tests/test_sampler.py
Comment thread thermox/sampler.py Outdated
@KaelanDt

KaelanDt commented Jun 4, 2024

Copy link
Copy Markdown
Contributor

Finished the speedup comparison with and without associative scan and adapted the handling of random keys in _sample_identity_diffusion, should be good to review

@SamDuffield

Copy link
Copy Markdown
Contributor Author

Be sure to add the underscore to sample_identity_diffusion to become _sample_identity_diffusion and maybe remove the docstring too

@KaelanDt KaelanDt left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm!

@KaelanDt KaelanDt merged commit 3446bd4 into main Jun 5, 2024
@KaelanDt KaelanDt deleted the associative_scan branch June 5, 2024 12:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants